package linc.fun.openai.service.impl;

import cn.hutool.core.util.StrUtil;
import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.spring.service.impl.ServiceImpl;
import jakarta.annotation.Resource;
import linc.fun.openai.config.openai.OpenaiConfig;
import linc.fun.openai.domain.dto.query.ChatMessagePageQuery;
import linc.fun.openai.domain.dto.request.ChatProcessRequest;
import linc.fun.openai.domain.entity.chat.ChatMessageDO;
import linc.fun.openai.domain.entity.chat.ChatRoomDO;
import linc.fun.openai.domain.vo.ChatMessageVO;
import linc.fun.openai.enums.ApiTypeEnum;
import linc.fun.openai.enums.ChatMessageStatusEnum;
import linc.fun.openai.enums.ChatMessageTypeEnum;
import linc.fun.openai.exception.BizException;
import linc.fun.openai.handler.converter.ChatMessageConverter;
import linc.fun.openai.handler.emitter.*;
import linc.fun.openai.mapper.ChatMessageMapper;
import linc.fun.openai.service.ChatMessageService;
import linc.fun.openai.service.ChatRoomService;
import linc.fun.openai.util.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;

import java.time.LocalDateTime;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;

import static linc.fun.openai.domain.entity.chat.table.Tables.*;

/**
 * Chat message service implementation.
 *
 * @author linc
 * @date 2023-3-27
 */
@Slf4j
@Service
public class ChatMessageServiceImpl extends ServiceImpl<ChatMessageMapper, ChatMessageDO> implements ChatMessageService {

    /**
     * Timeout of the streaming response emitter: 3 minutes, in milliseconds.
     */
    private static final long EMITTER_TIMEOUT_MILLIS = 3 * 60 * 1000L;

    @Resource
    private OpenaiConfig openaiConfig;
    @Resource
    private ChatRoomService chatRoomService;
    @Resource
    private IdGenerator idGenerator;

    /**
     * Pages chat messages, applying each filter only when the corresponding query field is present.
     *
     * @param query paging parameters plus optional content / IP / chat room filters
     * @return a page of chat message view objects
     */
    @Override
    public Page<ChatMessageVO> pageChatMessage(ChatMessagePageQuery query) {
        QueryWrapper queryWrapper = QueryWrapper.create()
                // fuzzy match on message content
                .where(ChatMessage.Content.like(query.getContent()).when(StrUtil.isNotBlank(query.getContent())))
                // fuzzy match on client IP
                .and(ChatMessage.Ip.like(query.getIp()).when(StrUtil.isNotBlank(query.getIp())))
                // restrict to a specific chat room
                .and(ChatMessage.ChatRoomId.eq(query.getChatRoomId()).when(Objects.nonNull(query.getChatRoomId())));
        // MyBatis-Flex's Page constructor takes (pageNumber, pageSize) in that order
        Page<ChatMessageDO> page = this.page(new Page<>(query.getPageNum(), query.getPageSize()), queryWrapper);
        return PageUtil.toPage(page, ChatMessageConverter.INSTANCE.entityToVO(page.getRecords()));
    }

    /**
     * Creates a streaming emitter for the request and runs the request through the handler chain.
     *
     * @param chatProcessRequest the incoming chat request
     * @return the emitter the handler chain writes the response to
     */
    @Override
    public ResponseBodyEmitter sendMessage(ChatProcessRequest chatProcessRequest) {
        ResponseBodyEmitter emitter = new ResponseBodyEmitter(EMITTER_TIMEOUT_MILLIS);
        emitter.onCompletion(() -> log.info("请求参数：{}，前端关闭emitter连接.", ObjectMapperUtil.toJson(chatProcessRequest)));
        emitter.onTimeout(() -> log.info("请求参数：{}，后端关闭emitter连接.", ObjectMapperUtil.toJson(chatProcessRequest)));

        buildEmitterChain().doChain(chatProcessRequest, emitter);
        return emitter;
    }

    /**
     * Wires the emitter handler chain in order: IP rate limiting, then sensitive-word filtering,
     * then user usage accounting, then message processing.
     *
     * @return the head of the chain
     */
    private ResponseEmitterChain buildEmitterChain() {
        ResponseEmitterChain ipRateLimiterEmitterChain = new IpRateLimiterEmitterChain();
        ResponseEmitterChain sensitiveWordEmitterChain = new SensitiveWordEmitterChain();
        ChatMessageEmitterChain chatMessageEmitterChain = new ChatMessageEmitterChain();
        ChatUserUsageChain chatUserUsageChain = new ChatUserUsageChain();

        ipRateLimiterEmitterChain.setNext(sensitiveWordEmitterChain);
        sensitiveWordEmitterChain.setNext(chatUserUsageChain);
        chatUserUsageChain.setNext(chatMessageEmitterChain);
        return ipRateLimiterEmitterChain;
    }

    /**
     * Builds and persists the QUESTION message record for a chat request.
     *
     * @param chatProcessRequest the incoming chat request
     * @param apiTypeEnum        which API type the request is served through
     * @return the saved message record
     * @throws BizException when the referenced parent message does not exist, the access-token model
     *                      does not match the parent's, or the API-key context question limit is exceeded
     */
    @Transactional(rollbackFor = Exception.class)
    @Override
    public ChatMessageDO initChatMessage(ChatProcessRequest chatProcessRequest, ApiTypeEnum apiTypeEnum) {
        // capture one instant so createTime and updateTime are identical on a freshly created row
        LocalDateTime now = LocalDateTime.now();

        ChatMessageDO chatMessageDO = new ChatMessageDO();
        chatMessageDO.setId(idGenerator.getId());
        // the message id is generated here rather than by the upstream API
        chatMessageDO.setMessageId(UUID.randomUUID().toString());
        chatMessageDO.setMessageType(ChatMessageTypeEnum.QUESTION);
        chatMessageDO.setApiType(apiTypeEnum);
        if (apiTypeEnum == ApiTypeEnum.API_KEY) {
            chatMessageDO.setApiKey(openaiConfig.getApiKey());
        }
        chatMessageDO.setUserId(ChatUserUtil.getUserId());
        chatMessageDO.setContent(chatProcessRequest.getPrompt());
        chatMessageDO.setModelName(openaiConfig.getApiModel());
        chatMessageDO.setOriginalData(null);
        // -1 marks token counts as not yet known
        chatMessageDO.setPromptTokens(-1);
        chatMessageDO.setCompletionTokens(-1);
        chatMessageDO.setTotalTokens(-1);
        chatMessageDO.setIp(WebUtil.getIp());
        chatMessageDO.setStatus(ChatMessageStatusEnum.INIT);
        chatMessageDO.setCreateTime(now);
        chatMessageDO.setUpdateTime(now);

        // link to the parent message, or open a new chat room when there is none
        this.populateInitParentMessage(chatMessageDO, chatProcessRequest);
        this.save(chatMessageDO);
        return chatMessageDO;
    }

    /**
     * Populates the parent-message fields of a new message.
     * <p>
     * When the request carries both a parent message id and a conversation id, the parent ANSWER
     * message of the same user, conversation and API type is looked up and its chat room and
     * context counters are inherited. Otherwise a new chat room is created and counters start at 1.
     *
     * @param chatMessageDO      the message being initialised; its userId and apiType must already be set
     * @param chatProcessRequest the incoming chat request
     * @throws BizException when the parent message is missing, the access-token model does not match,
     *                      or the API-key context question limit is exceeded
     */
    private void populateInitParentMessage(ChatMessageDO chatMessageDO, ChatProcessRequest chatProcessRequest) {
        Optional<ChatProcessRequest.Options> options = Optional.ofNullable(chatProcessRequest.getOptions());
        String parentMessageId = options.map(ChatProcessRequest.Options::getParentMessageId).orElse(null);
        String conversationId = options.map(ChatProcessRequest.Options::getConversationId).orElse(null);

        if (!StrUtil.isAllNotBlank(parentMessageId, conversationId)) {
            // no usable parent reference: start a new chat room
            ChatRoomDO chatRoomDO = chatRoomService.createChatRoom(chatMessageDO);
            chatMessageDO.setChatRoomId(chatRoomDO.getId());
            chatMessageDO.setContextCount(1);
            chatMessageDO.setQuestionContextCount(1);
            return;
        }

        // the parent must be an ANSWER of the same user, conversation and API type
        QueryWrapper queryWrapper = QueryWrapper.create()
                .where(ChatMessage.UserId.eq(chatMessageDO.getUserId()))
                .and(ChatMessage.MessageId.eq(parentMessageId))
                .and(ChatMessage.ConversationId.eq(conversationId))
                .and(ChatMessage.ApiType.eq(chatMessageDO.getApiType()))
                .and(ChatMessage.MessageType.eq(ChatMessageTypeEnum.ANSWER));
        ChatMessageDO parentChatMessage = this.getOne(queryWrapper);
        if (Objects.isNull(parentChatMessage)) {
            throw BizException.PARENT_MESSAGE_NOT_EXIST;
        }

        chatMessageDO.setParentMessageId(parentMessageId);
        chatMessageDO.setParentAnswerMessageId(parentMessageId);
        chatMessageDO.setParentQuestionMessageId(parentChatMessage.getParentQuestionMessageId());
        chatMessageDO.setChatRoomId(parentChatMessage.getChatRoomId());
        chatMessageDO.setConversationId(parentChatMessage.getConversationId());
        chatMessageDO.setContextCount(parentChatMessage.getContextCount() + 1);
        chatMessageDO.setQuestionContextCount(parentChatMessage.getQuestionContextCount() + 1);

        // an access-token conversation cannot switch models mid-conversation
        if (chatMessageDO.getApiType() == ApiTypeEnum.ACCESS_TOKEN
                && !Objects.equals(chatMessageDO.getModelName(), parentChatMessage.getModelName())) {
            throw BizException.CURRENT_ACCESS_TOKEN_AND_MODEL_NOT_MATCH;
        }

        // API-key mode: cap the number of questions in one context; a limit <= 0 disables the cap
        if (chatMessageDO.getApiType() == ApiTypeEnum.API_KEY
                && openaiConfig.getLimitQuestionContextCount() > 0
                && chatMessageDO.getQuestionContextCount() > openaiConfig.getLimitQuestionContextCount()) {
            log.warn("当前允许连续对话的问题数量为[{}]次，已达到上限，请关闭上下文对话重新发送", openaiConfig.getLimitQuestionContextCount());
            throw BizException.QUESTION_COUNT_HAS_LIMIT;
        }
    }
}
